{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65b3f9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Module entry call; onekey_show is project-provided (presumably displays the description\n",
    "# of the named component - confirm against onekey_algo).\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "\n",
    "onekey_show('人机对比')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321dd425",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import shutil\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import onekey_algo.custom.components as okcomp\n",
    "from onekey_algo import get_param_in_cwd\n",
    "\n",
    "plt.rcParams['figure.dpi'] = 300\n",
    "os.makedirs('img', exist_ok=True)\n",
    "# Configuration: label column name from get_param_in_cwd('task_column'),\n",
    "# falling back to 'label' when that value is falsy.\n",
    "task = get_param_in_cwd('task_column') or 'label'\n",
    "# All four inputs live in one folder; edit data_dir only when the data moves.\n",
    "data_dir = r'C:/Users/onekey/Desktop/onekey_comp/comp8-Modules/data/compare_with_human'\n",
    "label_file = f'{data_dir}/label.csv'\n",
    "AI_file = f'{data_dir}/AI.csv'\n",
    "human_only_file = f'{data_dir}/human_only.csv'\n",
    "AI_assistant_file = f'{data_dir}/ai_assist.csv'\n",
    "\n",
    "# Read the label file; every column containing a missing value is dropped.\n",
    "labels = [task]\n",
    "label_data_ = pd.read_csv(label_file).dropna(axis=1)\n",
    "print(label_data_.columns)\n",
    "\n",
    "# dropna(axis=1) silently removes a required column if it holds a NaN;\n",
    "# fail here with a clear message instead of a bare KeyError further down.\n",
    "missing_columns = [col for col in ['ID'] + labels if col not in label_data_.columns]\n",
    "if missing_columns:\n",
    "    raise KeyError(\n",
    "        f'{label_file}: required column(s) {missing_columns} are absent '\n",
    "        f'or were dropped because they contain missing values'\n",
    "    )\n",
    "\n",
    "ids = label_data_['ID']\n",
    "label_data = label_data_[['ID'] + labels]\n",
    "label_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac5ee2cf",
   "metadata": {},
   "source": [
    "# 人机对比"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5002786",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "subset = 'train'\n",
    "# AI scores: the CSV's three columns are renamed positionally to ID / '-0' / 'AI'.\n",
    "ai_scores = pd.read_csv(AI_file)\n",
    "ai_scores.columns = ['ID', '-0', 'AI']\n",
    "# Unaided reader results; every column after the first is treated as one reader.\n",
    "human_results = pd.read_csv(human_only_file)\n",
    "model_names = ['AI', *human_results.columns[1:]]\n",
    "\n",
    "# Inner-join AI and readers on ID, join the label table on the shared columns,\n",
    "# then drop every column that still holds a missing value.\n",
    "ai_and_humans = pd.merge(ai_scores, human_results, on='ID', how='inner')\n",
    "ALL_results = pd.merge(ai_and_humans, label_data_).dropna(axis=1)\n",
    "ALL_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e04731e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "# One copy of the ground truth per curve, paired with each model's / reader's score column.\n",
    "gt = [np.array(ALL_results[task]) for _ in model_names]\n",
    "pred_train = [np.array(ALL_results[name]) for name in model_names]\n",
    "okcomp.comp1.draw_roc(gt, pred_train, labels=model_names, title='Model AUC', auto_point=True)\n",
    "plt.savefig('img/Human_auc.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4826febd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "\n",
    "metric_columns = ['Signature', 'Accuracy', 'AUC', '95% CI', 'Sensitivity', 'Specificity',\n",
    "                  'PPV', 'NPV', 'Precision', 'Recall', 'F1', 'Threshold', 'Cohort']\n",
    "metric = []\n",
    "for reader, y_true, y_score in zip(model_names, gt, pred_train):\n",
    "    # Compute the metrics for this reader / model (called with use_youden=False).\n",
    "    (acc, auc, ci, tpr, tnr, ppv, npv,\n",
    "     precision, recall, f1, thres) = analysis_pred_binary(y_true, y_score, use_youden=False)\n",
    "    # Render the confidence interval bounds as one 'low - high' string.\n",
    "    ci_text = f\"{ci[0]:.4f} - {ci[1]:.4f}\"\n",
    "    metric.append((reader, acc, auc, ci_text, tpr, tnr, ppv, npv,\n",
    "                   precision, recall, f1, thres, 'Without AI'))\n",
    "a = pd.DataFrame(metric, columns=metric_columns)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e33db26a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group readers by their name minus its last character; the 'AI' row keeps its own name.\n",
    "a['group'] = a['Signature'].map(lambda x: x if x == 'AI' else x[:-1])\n",
    "# 'Signature', '95% CI' and 'Cohort' hold strings, so average the numeric metrics only\n",
    "# (pandas >= 2.0 raises TypeError when mean() meets object columns).\n",
    "aa = a.groupby('group').mean(numeric_only=True).reset_index()\n",
    "aa['Cohort'] = 'Without AI'\n",
    "aa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69899f07",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.delong import delong_roc_test\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "\n",
    "plt.figure(figsize=(15, 12))\n",
    "# Pairwise DeLong comparison: only the strict lower triangle (i > j) is computed,\n",
    "# the diagonal and the upper triangle stay NaN.\n",
    "n_models = len(model_names)\n",
    "cm = np.full((n_models, n_models), np.nan)\n",
    "for i, mni in enumerate(model_names):\n",
    "    for j, mnj in enumerate(model_names[:i]):\n",
    "        # [0][0] of the test result is plotted (presumably the p-value - confirm in onekey_algo).\n",
    "        cm[i, j] = delong_roc_test(ALL_results[task], ALL_results[mni], ALL_results[mnj])[0][0]\n",
    "# The first row and the last column are entirely NaN, so trim them before plotting.\n",
    "cm = pd.DataFrame(cm[1:, :-1], index=model_names[1:], columns=model_names[:-1])\n",
    "draw_matrix(cm, annot=True, cmap='jet_r', cbar=True)\n",
    "plt.title(f'Cohort {subset} Delong')\n",
    "plt.savefig(f'img/delong_each_cohort_{subset}.svg', bbox_inches = 'tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70332018",
   "metadata": {},
   "source": [
    "# AI 辅助"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c16229",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "subset = 'test'\n",
    "# AI scores: the CSV's three columns are renamed positionally to ID / '-0' / 'AI'.\n",
    "ai_scores = pd.read_csv(AI_file)\n",
    "ai_scores.columns = ['ID', '-0', 'AI']\n",
    "# AI-assisted reader results; every column after the first is treated as one reader.\n",
    "human_results = pd.read_csv(AI_assistant_file)\n",
    "model_names = ['AI', *human_results.columns[1:]]\n",
    "\n",
    "# Inner-join AI and readers on ID, join the label table on the shared columns,\n",
    "# then drop every column that still holds a missing value.\n",
    "ai_and_humans = pd.merge(ai_scores, human_results, on='ID', how='inner')\n",
    "ALL_results = pd.merge(ai_and_humans, label_data_).dropna(axis=1)\n",
    "ALL_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "778195f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.markers import MarkerStyle\n",
    "\n",
    "# One copy of the ground truth per curve, paired with each model's / reader's score column.\n",
    "gt = [np.array(ALL_results[task]) for _ in model_names]\n",
    "pred_train = [np.array(ALL_results[name]) for name in model_names]\n",
    "okcomp.comp1.draw_roc(gt, pred_train, labels=model_names, title='Model AUC', auto_point=True)\n",
    "plt.savefig('img/Human_auc_AI.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a326f4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "\n",
    "metric_columns = ['Signature', 'Accuracy', 'AUC', '95% CI', 'Sensitivity', 'Specificity',\n",
    "                  'PPV', 'NPV', 'Precision', 'Recall', 'F1', 'Threshold', 'Cohort']\n",
    "metric = []\n",
    "for reader, y_true, y_score in zip(model_names, gt, pred_train):\n",
    "    # Compute the metrics for this reader / model (called with use_youden=False).\n",
    "    (acc, auc, ci, tpr, tnr, ppv, npv,\n",
    "     precision, recall, f1, thres) = analysis_pred_binary(y_true, y_score, use_youden=False)\n",
    "    # Render the confidence interval bounds as one 'low - high' string.\n",
    "    ci_text = f\"{ci[0]:.4f} - {ci[1]:.4f}\"\n",
    "    metric.append((reader, acc, auc, ci_text, tpr, tnr, ppv, npv,\n",
    "                   precision, recall, f1, thres, 'With AI'))\n",
    "b = pd.DataFrame(metric, columns=metric_columns)\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf3efb69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group readers by their name minus its last character; the 'AI' row keeps its own name.\n",
    "b['group'] = b['Signature'].map(lambda x: x if x == 'AI' else x[:-1])\n",
    "# 'Signature', '95% CI' and 'Cohort' hold strings, so average the numeric metrics only\n",
    "# (pandas >= 2.0 raises TypeError when mean() meets object columns).\n",
    "ba = b.groupby('group').mean(numeric_only=True).reset_index()\n",
    "ba['Cohort'] = 'With AI'\n",
    "ba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e26328",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.delong import delong_roc_test\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "\n",
    "plt.figure(figsize=(15, 12))\n",
    "# Pairwise DeLong comparison: only the strict lower triangle (i > j) is computed,\n",
    "# the diagonal and the upper triangle stay NaN.\n",
    "n_models = len(model_names)\n",
    "cm = np.full((n_models, n_models), np.nan)\n",
    "for i, mni in enumerate(model_names):\n",
    "    for j, mnj in enumerate(model_names[:i]):\n",
    "        # [0][0] of the test result is plotted (presumably the p-value - confirm in onekey_algo).\n",
    "        cm[i, j] = delong_roc_test(ALL_results[task], ALL_results[mni], ALL_results[mnj])[0][0]\n",
    "# The first row and the last column are entirely NaN, so trim them before plotting.\n",
    "cm = pd.DataFrame(cm[1:, :-1], index=model_names[1:], columns=model_names[:-1])\n",
    "draw_matrix(cm, annot=True, cmap='jet_r', cbar=True)\n",
    "plt.title(f'Cohort {subset} Delong')\n",
    "plt.savefig(f'img/delong_each_cohort_{subset}_AI.svg', bbox_inches = 'tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ace79ec",
   "metadata": {},
   "source": [
    "# 辅助前后的变化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c379d18c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib.ticker import FormatStrFormatter\n",
    "\n",
    "# Stack the 'Without AI' and 'With AI' metric tables and keep the human readers only.\n",
    "# Relies on the 'group' column that the grouping cells add to a and b.\n",
    "c = pd.concat([a, b], axis=0)\n",
    "c = c[(c['Signature'] != 'AI')]\n",
    "\n",
    "for m in ['AUC', 'Accuracy', 'Sensitivity', 'Specificity']:\n",
    "    # One point plot per metric: x is the cohort (without / with AI), one line per reader group.\n",
    "    grid = sns.catplot(\n",
    "        data=c.reset_index(), kind='point',\n",
    "        x='Cohort', y=m,\n",
    "        hue='group'\n",
    "    )\n",
    "    grid.set_xlabels(visible=False)\n",
    "    # Lower y-limit only: 0.4 for Sensitivity, 0.5 for the other metrics.\n",
    "    plt.ylim(0.5 if m != 'Sensitivity' else 0.4)\n",
    "    grid.ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))\n",
    "    # Forward slash keeps the file inside img/ on every OS (a backslash is a separator only on Windows).\n",
    "    plt.savefig(f'img/human_boost_{m}.svg', bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a5fb7c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
